/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2019-2022. All rights reserved.
 * Description: graph_verify_util.cpp
 */

#include "framework/graph/core/infershape/graph_verify_util.h"

#include <vector>

#include "infra/base/assertion.h"

// inc/framework
#include "framework/infra/log/log.h"
#include "framework/graph/core/cgraph/compute_graph.h"
#include "framework/graph/core/cgraph/graph_list_walker.h"
#include "framework/graph/core/node/node.h"
#include "framework/graph/core/node/node_spec.h"
#include "framework/graph/core/node/node_functor.h"
#include "framework/graph/core/infershape/op_ir_ctx.h"
#include "framework/graph/core/infershape/op_ir_func_factory.h"
#include "framework/graph/core/infershape/op_ir_facade.h"
#include "framework/graph/utils/checker/graph_checker.h"

namespace ge {
namespace {
class GraphVerifier : private NodeFunctor {
public:
    explicit GraphVerifier(const ComputeGraph& graph) : graph_(graph) {}
    ~GraphVerifier() override = default;

    hiai::Status VerifyGraphNodes()
    {
        HIAI_EXPECT_EXEC(graph_.ROLE(GraphListWalker).WalkAllNodes(ToVisitor()));

        PrintErrInfo();

        return infos_.empty() ? hiai::SUCCESS : hiai::FAILURE;
    }

private:
    hiai::Status Visit(Node& node) override
    {
        HIAI_EXPECT_TRUE(node.ROLE(NodeSpec).OpDesc().IsDimValid());

        OpVerifyFunc verifyFunc = OpIRFuncFactory::Instance()->GetVerifyFunc(node);
        if (verifyFunc != nullptr) {
            RunVerifyFunc(node, verifyFunc);
        }

        return hiai::SUCCESS;
    }

    void RunVerifyFunc(Node& node, OpVerifyFunc& func)
    {
        OpIRFacade opIRFacade(node);
        InferContext inferContext;
        inferContext.SetOpIRFacade(opIRFacade);
        if (func(inferContext) != hiai::SUCCESS) {
            infos_.push_back(inferContext.GetVerifyInfo());
        }
    }

    void PrintErrInfo() const
    {
        for (const VerifyInfo& info : infos_) {
            for (const std::string& msg : info.messages) {
                FMK_LOGE("[op:%s type:%s] Verify failed, %s", info.name.c_str(), info.type.c_str(), msg.c_str());
            }
        }
    }

private:
    const ComputeGraph& graph_;
    std::vector<VerifyInfo> infos_;
};
} // namespace

// Verifies a compute graph: first checks that its inputs are fully linked
// (isNeedVerifySubGraph is forwarded to that check), then runs the per-node
// verification and returns its status.
hiai::Status GraphVerifyUtil::Verify(const ComputeGraph& graph, bool isNeedVerifySubGraph)
{
    HIAI_EXPECT_TRUE(GraphChecker::IsInputsFullyLinked(graph, isNeedVerifySubGraph));

    GraphVerifier nodeVerifier{graph};
    const hiai::Status verifyResult = nodeVerifier.VerifyGraphNodes();
    return verifyResult;
}
} // namespace ge
